import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import pipeline, AutoTokenizer
from sklearn.metrics import balanced_accuracy_score, precision_recall_fscore_support, accuracy_score
from sklearn.metrics import f1_score as f1
from sklearn.metrics import matthews_corrcoef as mcc
import warnings
from transformers import logging
# Keep the run log quiet: only transformers errors are shown, and Python
# UserWarning/FutureWarning messages are suppressed.
logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# Disable pandas' chained-assignment warning.
pd.options.mode.chained_assignment = None

# Read the access token that is later passed to the pipeline as `token=`.
# The context manager closes the file handle, and strip() removes the
# trailing newline/whitespace that would otherwise become part of the token.
with open('./llama_key.txt', 'r') as key_file:
    llama_key = key_file.read().strip()

# Tweets to classify; the loop below reads the `text` column.
covid = pd.read_csv('./data/covid_tweets_labeled.csv', encoding="ISO-8859-1")

############
## Zero Shot
############
# Checkpoint identifier handed to the pipeline factory.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Text-generation pipeline: weights loaded in bfloat16, placed on the CUDA
# device, with the token read from llama_key.txt supplied for model access.
pipe = pipeline(
    task="text-generation",
    model=model_id,
    token=llama_key,
    device_map="cuda",
    model_kwargs=dict(torch_dtype=torch.bfloat16),
)

# Zero-shot prompt template. The single `{}` placeholder is filled with the
# tweet text via str.format() in the classification loop; the model is asked
# to answer with just "1" (minimizing) or "0" (not minimizing).
user_message = """You are a classifier that determines if a tweet is minimizing the threat of COVID-19.

Tweets that minimize the threat of COVID-19 do one or more of the following:
- Question how dangerous the virus is or claim that death totals related to the virus are inflated or fake.
- Make statements that are anti-vaccination or claim that the COVID-19 vaccine is ineffective or dangerous.
- Criticize the use of masks or mask mandates.
- Criticize social distancing guidelines, lockdowns, or stay-at-home orders.
- Insinuate that COVID-19 is a conspiracy or a tool to control people.

I will show you a tweet. Read the tweet and then determine if the tweet is minimizing the threat of COVID-19. Here is the tweet:

{}

If the tweet is minimizing the threat of COVID-19, return 1. If it is not minimizing the threat of COVID-19, return 0. Do not explain your answer, and only return the number."""

# Classify every tweet, one prompt per pipeline call, collecting the raw
# pipeline result dicts in `res` (one per tweet, in row order).
data = covid
res = []

# Loop-invariant lookups hoisted out of the per-tweet loop.
tokenizer = pipe.tokenizer
pad_id = tokenizer.eos_token_id

for doc in data['text']:
    messages = [
        {"role": "user", "content": user_message.format(doc)},
    ]
    # Render the chat as a plain string ending with the assistant header so
    # the model's next tokens are its answer.
    # NOTE(review): the rendered template may already contain the BOS token
    # and the pipeline may add another when it tokenizes the string - confirm
    # against the installed transformers version.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Greedy decoding (do_sample=False) makes the labels deterministic.
    # `temperature` is a sampling-only setting that greedy decoding ignores,
    # and 0 is not a valid sampling temperature, so it is not passed.
    outputs = pipe(
        prompt,
        max_new_tokens=2,
        do_sample=False,
        return_full_text=False,
        pad_token_id=pad_id,
    )
    res.extend(outputs)

# Extract labels
# Pull the generated string out of each pipeline result dict.
res = [text['generated_text'] for text in res]
# Keep the first non-whitespace character of each answer. Stripping stops a
# leading space/newline from hiding a "1", and slicing (rather than indexing)
# yields '' instead of raising IndexError when the generation is empty.
res = [text.strip()[:1] for text in res]
# Anything other than an explicit "1" (including an empty answer) is labeled 0.
covid['llama_zs'] = [1 if text == '1' else 0 for text in res]

# Export
# Overwrite the input CSV with the new `llama_zs` column. The file was read
# as ISO-8859-1, so it is written back with the same encoding; the default
# (UTF-8) would garble non-ASCII characters the next time the file is read
# as ISO-8859-1.
covid.to_csv('./data/covid_tweets_labeled.csv', index=False, encoding="ISO-8859-1")